"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Create the hierarchical synthetic temporal-network examples for Fig. 3.

"""

import sys
import os
# Add the directory one level above this script to the import path; presumably
# that is where the project modules imported below (SynthTempNetwork,
# TemporalNetwork) live.
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(
        os.path.join(os.getcwd(), os.path.expanduser(__file__))
    )
)
sys.path.append(
    os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
)


import numpy as np
from SynthTempNetwork import Individual, SynthTempNetwork
from TemporalNetwork import ContTempNetwork
import pickle

import matplotlib.pyplot as plt

# Output directory for the parameter file and the generated networks.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory (not SCRIPT_DIR) — run the script from its own folder.
save_dir = '../paper_data/synthtempnet/hierarchic_net_paper'


os.makedirs(save_dir, exist_ok=True)
# NOTE(review): unconditional stop. Running the whole file only creates
# save_dir and then aborts here; the '#%%' cells below appear to be meant for
# one-at-a-time execution (e.g. in an IDE with cell support) — confirm.
raise Exception

#%%
#inter_tau = 1
#activ_tau = 2
#num_inter_per_activ = 3


# 3-level hierarchy

# Simulation parameters.
inter_tau = 10            # passed to Individual as inter_distro_scale
activ_tau = 10            # passed to Individual as activ_distro_scale
num_inter_per_activ = 1   # passed to SynthTempNetwork as num_interactions_per_activation
t_start, t_end = 0, 1000  # simulated time window
n_groups = 27             # must equal the size of the 27x27 group matrix built below
n_per_group = 3           # individuals per group
individuals = []

# Individuals get consecutive integer ids: group g owns ids
# g*n_per_group, ..., (g+1)*n_per_group - 1.
for g in range(n_groups):
    individuals += [
        Individual(i,
                   inter_distro_scale=inter_tau,
                   activ_distro_scale=activ_tau,
                   group=g)
        for i in range(g * n_per_group, (g + 1) * n_per_group)
    ]


# Ratio of the level-1 (self) connection strength to the level-2 one (b1 / b2).
r12 = 10
# Ratio of the level-1 (self) connection strength to the level-3 one (b1 / b3).
r13 = r12*10
# level1 strength to background
# r14 = r13*10

# Each row of the matrix below contains one b1, two b2 and six b3 entries
# (plus zeros, since b4 = 0), so b1 is chosen to make every row sum to 1.
b1 = 1/(1+2/r12+6/r13)
b2 = b1/r12
b3 = b1/r13
b4 = 0

# 3x3 building blocks: a "sub-module" block with b1 on the diagonal and b2
# elsewhere, and a uniform b3 block linking different sub-modules.
inter_module_probs = np.full((3, 3), b3)
sub_module_probs = np.full((3, 3), b2)
np.fill_diagonal(sub_module_probs, b1)

# 9x9 block B: sub-module blocks on the diagonal, b3 blocks off the diagonal.
B0 = np.block([sub_module_probs, inter_module_probs, inter_module_probs])
B1 = np.block([inter_module_probs, sub_module_probs, inter_module_probs])
B2 = np.block([inter_module_probs, inter_module_probs, sub_module_probs])
B = np.vstack((B0, B1, B2))

# 27x27 matrix C: three copies of B on the diagonal, b4 (= 0) elsewhere.
Cinter = np.full_like(B, b4)
C = np.block([[B, Cinter, Cinter],
              [Cinter, B, Cinter],
              [Cinter, Cinter, B]])

inter_group_probs = C

# Visual check of the group-to-group probability matrix, with a colour scale.
plt.colorbar(plt.matshow(inter_group_probs))

#%%
# Store the simulation parameters (including the full group probability
# matrix) in save_dir/sim_param.pickle, next to the generated networks.
with open(os.path.join(save_dir, 'sim_param.pickle'), 'wb') as fopen:
    pickle.dump(
        dict(inter_tau=inter_tau,
             activ_tau=activ_tau,
             num_inter_per_activ=num_inter_per_activ,
             t_start=t_start,
             t_end=t_end,
             n_groups=n_groups,
             n_per_group=n_per_group,
             r12=r12,
             r13=r13,
             inter_group_probs=inter_group_probs),
        fopen,
    )

#%%
# Run 10 simulations with the parameters above and save each resulting
# temporal network as save_dir/hiera_tempnet_paper_<k>.pickle.
for k in range(10):
    print(k)  # progress indicator
    # NOTE(review): the same `individuals` list is passed to every run; if
    # SynthTempNetwork mutates Individual state, later runs do not start from
    # a fresh population — confirm.
    sim = SynthTempNetwork(individuals=individuals, t_start=t_start, t_end=t_end,
                           next_event_method='block_probs',
                           inter_group_probs=inter_group_probs,
                           num_interactions_per_activation=num_inter_per_activ)

    sim.run(save_all_states=True, save_dt_states=True, verbose=False)

    # Build a ContTempNetwork from the simulated events, merging overlapping
    # events. (The former `net.events_table.head(10)` was removed: its result
    # was discarded inside the loop, so it never displayed anything.)
    net = ContTempNetwork(source_nodes=sim.indiv_sources,
                          target_nodes=sim.indiv_targets,
                          starting_times=sim.start_times,
                          ending_times=sim.end_times,
                          merge_overlapping_events=True)

    net.save(os.path.join(save_dir, 'hiera_tempnet_paper_{:d}.pickle'.format(k)))
    
#%%
    
